% For "The Climate Risk Premium" (Lemoine, JAERE, 2020)

% Epstein-Zin value function calculations and then scc calculation

% number of periods covered by the simulations and the value recursion
max_time = Params.horizon + Params.years_for_scc;

% Optional settings that inherit another field's value when absent.
% Each row is {field, fallback field}. Rows are processed in order because
% a later fallback may itself have been filled in by an earlier row.
inherit_pairs_ez = { ...
    'nodes',                               'basiscoeffs'; ...
    'nodetype',                            'basistype'; ...
    'quadrature_nodes_consvol',            'quadrature_nodes'; ...
    'quadrature_nodes_consvol_statespace', 'quadrature_nodes_consvol' };
for row_inherit_ez = 1:size(inherit_pairs_ez,1)
    if isfield(Params,inherit_pairs_ez{row_inherit_ez,1})
        continue;
    end
    Params.(inherit_pairs_ez{row_inherit_ez,1}) = Params.(inherit_pairs_ez{row_inherit_ez,2});
end
clear inherit_pairs_ez row_inherit_ez;

% flag a state-space simulation that is coarser than the welfare evaluation
if Params.quadrature_nodes_consvol > Params.quadrature_nodes_consvol_statespace
    warning('Fewer quadrature nodes in state space simulation than in welfare evaluation');
end

% switches with fixed default values
if ~isfield(Params,'uselogstates'), Params.uselogstates = 1; end
if ~isfield(Params,'resetjumps'), Params.resetjumps = 0; end


%% Obtain total population, for a normalization

% Accumulate Params.pop(t) over every period t = 1..max_time. The total is
% used further down as the denominator of the population weight
% Params.pop(t)/totalpop in the value-function recursion.
totalpop = 0;
for t=1:max_time
    totalpop = totalpop + Params.pop(t); % cannot pass a vector into Params.pop, so add one period at a time
end

        
%% Initial state, used for the final calculations and for checking the grid along the way

% assemble the state in levels: [consumption per capita, M1, M2, temperature]
start_state = [ Params.cons_start/Params.pop(1), Params.M1_start, Params.M2_start, Params.temp_start ];
if Params.uselogstates==1
    % consumption per capita and both carbon stocks are carried in logs;
    % temperature stays in levels
    start_state(1:3) = log(start_state(1:3));
end


%% Number of outer quadrature nodes for warming and damages (depends on the learning assumption)

% drop singleton dimensions from the node arrays
Params.cons_damage = squeeze(Params.cons_damage);
Params.warming = squeeze(Params.warming);

% default: a single outer pass; when Params.learn_ez==1 there is one outer
% pass per damage node and per warming node
quadnodes_outer_dam = 1;
quadnodes_outer_warm = 1;
if Params.learn_ez==1
    quadnodes_outer_dam = length(Params.cons_damage);
    quadnodes_outer_warm = length(Params.warming);
end


%% Preallocate coefficient storage

% With Params.learn_ez==1 only a single period is stored (so the array does
% not get huge); otherwise there is one row for each of the max_time-1 periods.
if Params.learn_ez==1
    periods_kept_ez = 1;
else
    periods_kept_ez = max_time-1;
end

% dimensions: period x basis coefficient x warming node x damage node
v_coeffs = zeros(periods_kept_ez, prod(Params.basiscoeffs), quadnodes_outer_warm, quadnodes_outer_dam);
if Params.perturbation>0
    % matching storage for the perturbed value function
    v_coeffs_perturb = zeros(periods_kept_ez, prod(Params.basiscoeffs), quadnodes_outer_warm, quadnodes_outer_dam);
end
clear periods_kept_ez;


%% Loop over quadrature nodes for climate sensitivity and damages

% Reuse a parallel pool when one is already open (for instance one left
% behind by a run that stopped before reaching delete(ParObj)); calling
% parpool while a pool exists raises an error. Otherwise start a pool with
% Params.workers workers.
ParObj = gcp('nocreate');
if isempty(ParObj)
    ParObj = parpool(Params.workers);
end

% keep untouched copies of the node arrays: Params.cons_damage and
% Params.warming are overwritten and restored repeatedly inside the loop
Params.cons_damage_original = Params.cons_damage;
Params.warming_original = Params.warming;

% earliest period handled by the backward recursion
if Params.learn_ez~=1
    start_time = 1;
else
    start_time = 2; % do not need first period because will do that after have quadrature nodes
end

% Outer loops: one pass per (damage node, warming node) pair. Both counts
% are 1 unless Params.learn_ez==1.
for index_node_outer_dam = 1:quadnodes_outer_dam
    for index_node_outer_warm = 1:quadnodes_outer_warm
        
        % progress message: prints the current (warming,damage) index pair and the total number of pairs
        disp(['Starting node_outer_warm-dam ' num2str(index_node_outer_warm) ',' num2str(index_node_outer_dam) ' out of ' num2str(quadnodes_outer_dam*quadnodes_outer_warm)]);
        
        % start every pass from the full, unmodified node arrays (they are overwritten further down)
        Params.cons_damage = Params.cons_damage_original;
        Params.warming = Params.warming_original;
                
        %% Consumption-volatility quadrature used to find the state space bounds
        
        if ~strcmp(option_consvolatility,'on')
            % volatility off: one degenerate node (zero increment) with unit weight
            Wiener.cons_increment = zeros(1,max_time);
            weights_cons = 1;
        else
            % nodes and weights from qnwnorm(n,0,1), with the node column
            % repeated once for each of the max_time periods
            [Wiener.cons_increment,weights_cons] = qnwnorm(Params.quadrature_nodes_consvol_statespace,0,1);
            Wiener.cons_increment = repmat(Wiener.cons_increment,1,max_time);
        end
        % temperature increments are zero in either case
        Wiener.temp_increment = zeros(1,max_time);
        
        
        
        %% Simulate to obtain state space bounds
        % First of two deterministic simulations. This one supplies the upper
        % bounds: its paths are saved in temp_M1, temp_M2, temp_T and
        % temp_cons_pc at the end of the section.
                
        % store original values (the increments are collapsed to a single row below)
        Wiener.cons_increment_original = Wiener.cons_increment;
                
        % Simulate at low-dam, high-cs, high-growth combo (for upper bounds on all states)
        if Params.learn_ez==1
            % with learning, use the parameters of the current outer node
            Params.cons_damage = Params.cons_damage(index_node_outer_dam);
            Params.warming = Params.warming(index_node_outer_warm);
        else
            Params.cons_damage = min(Params.cons_damage_original);
            Params.warming = min(Params.warming); % this is inverse of climate sensitivity
        end
        % largest increment across quadrature nodes (rows) in every period
        Wiener.cons_increment = max(Wiener.cons_increment,[],1);
        % initialize
        Paths.M1 = zeros(1,max_time-1);
        Paths.M2 = zeros(1,max_time-1);
        Paths.temp = zeros(1,max_time-1);
        Paths.cons = zeros(1,max_time-1);
        Paths.cons_pc = zeros(1,max_time-1);
        counter_fine = 0; % running count of fine time steps taken so far
        % loop through time
        for index_time = 1:max_time-1
            
            % Starting time
            if index_time==1
                Paths.M1(:,index_time) = Params.M1_start;
                Paths.M2(:,index_time) = Params.M2_start;
                Paths.temp(:,index_time) = Params.temp_start;
                Paths.cons(:,index_time) = Params.cons_start;
                Paths.cons_pc(:,index_time) = Params.cons_start/Params.pop(index_time);
            end
            
            % simulate stochastic processes (skipped in the last stored period, which has no successor column)
            if index_time < max_time-1
                
                M1_value = Paths.M1(:,index_time);
                M2_value = Paths.M2(:,index_time);
                temp_value = Paths.temp(:,index_time);
                cons_value = Paths.cons(:,index_time);
                
                % 1/Params.timestep fine steps make up one period
                for index_interval = 1:1/Params.timestep
                    
                    counter_fine = counter_fine + 1;
                    
                    % Calculate changes in variables
                    % sub_transitions is an external script executed in this
                    % workspace; the lines below rely on it to define M1change,
                    % M2change, tempchange and conschange.
                    % NOTE(review): presumably it reads M1_value, M2_value,
                    % temp_value, cons_value, index_time, counter_fine, Wiener
                    % and Params -- confirm against sub_transitions.
                    run sub_transitions;
                    
                    % Update values of variables
                    M1_value = M1_value + M1change;
                    M2_value = M2_value + M2change;
                    temp_value = temp_value + tempchange;
                    cons_value = cons_value + conschange;
                    
                end % end stochastic process simulation
                
                % store next period's value
                Paths.M1(:,index_time+1) = M1_value;
                Paths.M2(:,index_time+1) = M2_value;
                Paths.temp(:,index_time+1) = temp_value;
                Paths.cons(:,index_time+1) = cons_value;
                Paths.cons_pc(:,index_time+1) = cons_value/Params.pop(index_time+1);
                
                % remove per-period temporaries (forcing is not assigned in this file; presumably created by sub_transitions)
                clear M1_value M2_value temp_value cons_value;
                clear M1change M2change forcing tempchange conschange;
                
            end
            
        end % end time looping
        % store bounds (Paths is reused by the next simulation)
        temp_M1 = Paths.M1;
        temp_M2 = Paths.M2;
        temp_T = Paths.temp;
        temp_cons_pc = Paths.cons_pc;
        
        
        % Simulate at high-dam, high-cs, low-growth combo (for lower bound on cons per capita and thus also on M1 and M2)
        % Second deterministic simulation; its paths stay in Paths and supply
        % the lower bounds. When Params.learn_ez==1 the damage and warming
        % parameters are left as already set for the current outer node.
        if Params.learn_ez~=1
            Params.cons_damage = max(Params.cons_damage_original);
            Params.warming = min(Params.warming_original); % this is inverse of climate sensitivity
        end
        % smallest increment across quadrature nodes (rows) in every period
        Wiener.cons_increment = min(Wiener.cons_increment_original,[],1);
        % initialize
        Paths.M1 = zeros(1,max_time-1);
        Paths.M2 = zeros(1,max_time-1);
        Paths.temp = zeros(1,max_time-1);
        Paths.cons = zeros(1,max_time-1);
        Paths.cons_pc = zeros(1,max_time-1);
        counter_fine = 0; % running count of fine time steps taken so far
        % loop through time
        for index_time = 1:max_time-1
            
            % Starting time
            if index_time==1
                Paths.M1(:,index_time) = Params.M1_start;
                Paths.M2(:,index_time) = Params.M2_start;
                Paths.temp(:,index_time) = Params.temp_start;
                Paths.cons(:,index_time) = Params.cons_start;
                Paths.cons_pc(:,index_time) = Params.cons_start/Params.pop(index_time);
            end
            
            % simulate stochastic processes (skipped in the last stored period, which has no successor column)
            if index_time < max_time-1
                
                M1_value = Paths.M1(:,index_time);
                M2_value = Paths.M2(:,index_time);
                temp_value = Paths.temp(:,index_time);
                cons_value = Paths.cons(:,index_time);
                
                % 1/Params.timestep fine steps make up one period
                for index_interval = 1:1/Params.timestep
                    
                    counter_fine = counter_fine + 1;
                    
                    % Calculate changes in variables
                    % sub_transitions is an external script executed in this
                    % workspace; the lines below rely on it to define M1change,
                    % M2change, tempchange and conschange.
                    % NOTE(review): the set of workspace variables it reads is
                    % not visible here -- confirm against sub_transitions.
                    run sub_transitions;
                    
                    % Update values of variables
                    M1_value = M1_value + M1change;
                    M2_value = M2_value + M2change;
                    temp_value = temp_value + tempchange;
                    cons_value = cons_value + conschange;
                    
                end % end stochastic process simulation
                
                % store next period's value
                Paths.M1(:,index_time+1) = M1_value;
                Paths.M2(:,index_time+1) = M2_value;
                Paths.temp(:,index_time+1) = temp_value;
                Paths.cons(:,index_time+1) = cons_value;
                Paths.cons_pc(:,index_time+1) = cons_value/Params.pop(index_time+1);
                
                % remove per-period temporaries (forcing is not assigned in this file; presumably created by sub_transitions)
                clear M1_value M2_value temp_value cons_value;
                clear M1change M2change forcing tempchange conschange;
                
            end
            
        end % end time looping
        
        % Bounds of the grid in every period and dimension, stored as [min max].
        % Minima come from the paths still held in Paths, maxima from the
        % paths saved in temp_*. The multiplier grows with the period index.
        statespace_multiplier = 1.05 + (1/(max_time-1))*(1:max_time-1)';
        M1_bounds = [ min(Paths.M1.',[],2), max(temp_M1.',[],2) ];
        M2_bounds = [ min(Paths.M2.',[],2), max(temp_M2.',[],2) ];
        % the simulated paths give no clear lower bound on temperature, so a
        % small fraction of the starting temperature is used instead
        T_bounds = [ 0.01*(Params.temp_start*ones(max_time-1,1)), max(1.5*Params.temp_start, max(temp_T.',[],2)) ];
        Conspc_bounds = [ min(Paths.cons_pc.',[],2), max(temp_cons_pc.',[],2) ];
        
        % columns are ordered [cons_pc M1 M2 T]; minima are divided by the
        % multiplier and maxima are multiplied by it
        if Params.uselogstates~=1
            statemin(:,:,index_node_outer_warm,index_node_outer_dam) = ...
                (1./statespace_multiplier).*[ Conspc_bounds(:,1) M1_bounds(:,1) M2_bounds(:,1) T_bounds(:,1) ];
            statemax(:,:,index_node_outer_warm,index_node_outer_dam) = ...
                statespace_multiplier.*[ Conspc_bounds(:,2) M1_bounds(:,2) M2_bounds(:,2) T_bounds(:,2) ];
        else
            % first three states in logs, temperature in levels
            statemin(:,:,index_node_outer_warm,index_node_outer_dam) = ...
                (1./statespace_multiplier).*[ log(Conspc_bounds(:,1)) log(M1_bounds(:,1)) log(M2_bounds(:,1)) T_bounds(:,1) ];
            statemax(:,:,index_node_outer_warm,index_node_outer_dam) = ...
                statespace_multiplier.*[ log(Conspc_bounds(:,2)) log(M1_bounds(:,2)) log(M2_bounds(:,2)) T_bounds(:,2) ];
        end
        
        % Force the state space to widen monotonically over time (usually
        % already true, except at a very high damage, very high warming node):
        % each period's bound is combined with the previous period's bound.
        for index_time = 2:max_time-1
            statemin(index_time,:,index_node_outer_warm,index_node_outer_dam) = ...
                min( statemin(index_time-1:index_time,:,index_node_outer_warm,index_node_outer_dam), [], 1 );
            statemax(index_time,:,index_node_outer_warm,index_node_outer_dam) = ...
                max( statemax(index_time-1:index_time,:,index_node_outer_warm,index_node_outer_dam), [], 1 );
        end
        

        % discard the simulated paths and the saved upper-bound trajectories
        clear Paths;
        
        clear temp_M1 temp_M2 temp_T temp_cons_pc;
        
        % Restore original values
        Params.cons_damage = Params.cons_damage_original;
        Params.warming = Params.warming_original;
        % clear only accepts whole variable names, so it cannot drop a single
        % struct field; remove the saved copy of the increments with rmfield.
        if isfield(Wiener,'cons_increment_original')
            Wiener = rmfield(Wiener,'cons_increment_original');
        end
        
        
        %% Consumption-volatility quadrature used for the welfare evaluation
        
        if ~strcmp(option_consvolatility,'on')
            % volatility off: one degenerate node (zero increment) with unit weight
            Wiener.cons_increment = zeros(1,max_time);
            weights_cons = 1;
        else
            % nodes and weights from qnwnorm(n,0,1), with the node column
            % repeated once for each of the max_time periods
            [Wiener.cons_increment,weights_cons] = qnwnorm(Params.quadrature_nodes_consvol,0,1);
            Wiener.cons_increment = repmat(Wiener.cons_increment,1,max_time);
        end
        % temperature increments are zero in either case
        Wiener.temp_increment = zeros(1,max_time);
        
        
        
        %% Step backwards through value function
        
        % keep copies of the welfare-evaluation weights and increments so
        % that some nodes can be made deterministic
        weights_cons_original = weights_cons;
        Wiener.cons_increment_original = Wiener.cons_increment;
        testmin = [];
        testmax = [];
        
        if Params.learn_ez~=1
            % all nodes are used at once: damage nodes are laid out along
            % dimension 3 and warming nodes along dimension 2
            Params.cons_damage = reshape(Params.cons_damage,1,1,size(nodes_damage,1),1);
            Params.warming = reshape(Params.warming,1,size(nodes_climsens,1),1);
            if index_node_outer_warm==1 && index_node_outer_dam==1
                % the weights only need to be reshaped on the first pass
                weights_damage = reshape(weights_damage,1,1,length(weights_damage));
                weights_climsens = reshape(weights_climsens,1,length(weights_climsens),1);
            end
        else
            % keep only the parameter values of the current outer node
            Params.warming = Params.warming(index_node_outer_warm);
            Params.cons_damage = Params.cons_damage(index_node_outer_dam);
        end
        
        % nothing is carried over from a later period before the first step
        coeffs_perturb_next = [];
        coeffs_next = [];
        fspace_next = [];
        statematrix_next = [];
        
        % start timing the recursion
        tic;
                
        % Backward recursion from period max_time-1 down to start_time. Each
        % pass builds a grid for period t, evaluates the value at every grid
        % point, fits the approximant, and hands the fit to period t-1.
        for t = max_time-1:-1:start_time
            
            disp(['Value function recursion: time ' num2str(t)]);
            
            % Create nodes
            % fspace_now is the function space the value function is fitted in;
            % fspace_nodes (built from Params.nodetype/Params.nodes) is only
            % used to place the grid points. Both use period t's bounds.
            fspace_now = fundefn(Params.basistype,Params.basiscoeffs,statemin(t,:,index_node_outer_warm,index_node_outer_dam),statemax(t,:,index_node_outer_warm,index_node_outer_dam));
            fspace_nodes = fundefn(Params.nodetype,Params.nodes,statemin(t,:,index_node_outer_warm,index_node_outer_dam),statemax(t,:,index_node_outer_warm,index_node_outer_dam));
            nodes = funnode(fspace_nodes);
            statematrix = gridmake(nodes); % each row defines one point, where the columns give [cons_pc M1 M2 T]
            
            % check whether any are beyond continuation period's nodes and
            % change them if so (not needed in the first pass, which has no
            % continuation grid yet)
            if t < max_time-1
                % loop through dimensions, figuring out how many are beyond
                % edge nodes of next period's grid and resetting them to
                % something between that edge and nearest current interior nodes
                for index_dim = 1:size(statematrix,2)
                    if index_dim==3 % skip M2 because this carbon stock can grow over time and want allow trajectory to move with it
                        continue;
                    end
                    % distinct node values along this dimension, ascending
                    currentnodes = unique(statematrix(:,index_dim),'sorted');
                    % take care of ones that are too large
                    if sum(currentnodes > max(statematrix_next(:,index_dim))) > 0
                        currentedge = max(currentnodes( currentnodes < max(statematrix_next(:,index_dim)) )); % find largest node that is below next period's node
                        nodestochange = currentnodes( currentnodes >= max(statematrix_next(:,index_dim)) );
                        % evenly spaced replacement values between currentedge (excluded) and next period's largest node (included)
                        nodesnew = currentedge + ( max(statematrix_next(:,index_dim)) - currentedge )*[length(nodestochange):-1:1]/length(nodestochange);
                        for index_reset=1:length(nodestochange)
                            statematrix( statematrix(:,index_dim)==nodestochange(index_reset) , index_dim ) = nodesnew(index_reset);
                        end
                        disp(['Resetting ' num2str(length(nodestochange)) ' nodes that were too high, out of ' num2str(length(currentnodes)) ' nodes in dimension ' num2str(index_dim)]);
                    end
                    % take care of ones that are too small
                    if sum(currentnodes < min(statematrix_next(:,index_dim))) > 0
                        currentedge = min(currentnodes( currentnodes > min(statematrix_next(:,index_dim)) )); % find smallest node that is above next period's node
                        nodestochange = currentnodes( currentnodes <= min(statematrix_next(:,index_dim)) );
                        % evenly spaced replacement values between currentedge (excluded) and next period's smallest node (included)
                        nodesnew = currentedge - ( currentedge - min(statematrix_next(:,index_dim)) )*[1:length(nodestochange)]/length(nodestochange);
                        for index_reset=1:length(nodestochange)
                            statematrix( statematrix(:,index_dim)==nodestochange(index_reset) , index_dim ) = nodesnew(index_reset);
                        end
                        disp(['Resetting ' num2str(length(nodestochange)) ' nodes that were too low, out of ' num2str(length(currentnodes)) ' nodes in dimension ' num2str(index_dim)]);
                    end
                end
                % the starting state must stay inside the adjusted grid in
                % every dimension except M2 (column 3)
                if sum(start_state([1 2 4]) < min(statematrix(:,[1 2 4]),[],1)) > 0 || sum(start_state([1 2 4]) > max(statematrix(:,[1 2 4]),[],1)) > 0
                    error('Starting state is outside of grid at time %d in some dim other than M2, with number of nodes outside being %s',t,mat2str([start_state < min(statematrix,[],1); start_state > max(statematrix,[],1)]));
                end
                clear index_dim index_reset;
            end
            
            
            % Evaluate the value at every grid point in parallel, using at
            % most Params.workers workers. v and v_perturb are filled one row
            % per grid point.
            parfor (index_node = 1:size(statematrix,1),Params.workers)                
                
                % consumption per capita in levels at this grid point
                if Params.uselogstates==1
                    cons_pc_node = exp(statematrix(index_node,1));
                else
                    cons_pc_node = statematrix(index_node,1);
                end                
                
                % obtain continuation value
                % NOTE(review): contvalue is defined elsewhere; judging from
                % the exponent applied below, its result is raised to
                % (1-Params.time)/(1-Params.risk) before discounting -- confirm
                % what it returns.
                do_perturb = 0;                
                [next_value] = contvalue( do_perturb, coeffs_next, fspace_next, index_node, t, max_time, statematrix, statematrix_next, cons_pc_node, weights_damage, weights_climsens, weights_cons, totalpop, statemin(:,:,index_node_outer_warm,index_node_outer_dam), Wiener, Params, option_ocean, option_damagetype);
                
                % calculate value by combining with current consumption pc:
                % population-weighted consumption term with weight
                % 1-exp(-Params.discount), continuation term with weight
                % exp(-Params.discount)
                v(index_node,1) = ( (1-exp(-Params.discount))*(Params.pop(t)/totalpop)*cons_pc_node^(1-Params.time) +  exp(-Params.discount)*next_value^((1-Params.time)/(1-Params.risk))  )^(1/(1-Params.time));
                
                % stop if the value is complex or not strictly positive
                if ~isreal(v(index_node,1))
                    error('v(index_node,1) is complex')
                end
                if v(index_node,1)<=0 % note: also triggers when the value is exactly zero
                    error('v(index_node,1) is negative')
                end
                
                % same calculation with the perturbed continuation coefficients
                if Params.perturbation>0
                    
                    % obtain continuation value
                    do_perturb = 1;
                    [next_value] = contvalue( do_perturb, coeffs_perturb_next, fspace_next, index_node, t, max_time, statematrix, statematrix_next, cons_pc_node, weights_damage, weights_climsens, weights_cons, totalpop, statemin(:,:,index_node_outer_warm,index_node_outer_dam), Wiener, Params, option_ocean, option_damagetype);
                    
                    % calculate value by combining with current consumption pc
                    v_perturb(index_node,1) = ( (1-exp(-Params.discount))*(Params.pop(t)/totalpop)*cons_pc_node^(1-Params.time) +  exp(-Params.discount)*next_value^((1-Params.time)/(1-Params.risk))  )^(1/(1-Params.time));
                    
                end
                
            end
            
            % in period 1 the starting state must lie inside the grid in every dimension
            if t==1
                if sum(start_state < min(statematrix,[],1)) > 0 || sum(start_state > max(statematrix,[],1)) > 0
                    error('Starting state is outside of grid');
                end
            end
            
            
            % fit value function (and store it for use at end in case do
            % decomposition)
            if Params.learn_ez==1                
                % with learning only one row is kept, overwritten every period
                v_coeffs(1,:,index_node_outer_warm,index_node_outer_dam) = funfitxy(fspace_now, statematrix, v);
                coeffs_next = v_coeffs(1,:,index_node_outer_warm,index_node_outer_dam);
                if Params.perturbation>0
                    v_coeffs_perturb(1,:,index_node_outer_warm,index_node_outer_dam) = funfitxy(fspace_now, statematrix, v_perturb);
                    coeffs_perturb_next = v_coeffs_perturb(1,:,index_node_outer_warm,index_node_outer_dam);
                end
            else                
                % without learning the coefficients of every period are kept in row t
                v_coeffs(t,:,index_node_outer_warm,index_node_outer_dam) = funfitxy(fspace_now, statematrix, v);
                coeffs_next = v_coeffs(t,:,index_node_outer_warm,index_node_outer_dam);
                if Params.perturbation>0
                    v_coeffs_perturb(t,:,index_node_outer_warm,index_node_outer_dam) = funfitxy(fspace_now, statematrix, v_perturb);
                    coeffs_perturb_next = v_coeffs_perturb(t,:,index_node_outer_warm,index_node_outer_dam);
                end
            end
            % this period's function space and grid become the continuation objects for period t-1
            fspace_next = fspace_now;
            statematrix_next = statematrix;
            
            if ~isreal(v_coeffs)
                error('v_coeffs is complex')
            end           
            
        end
        
        % Record and report the time since tic. t still holds the last value
        % of the recursion loop here, so the second index equals start_time.
        % NOTE(review): index_otherloop is not set in this file; presumably
        % the calling script defines it -- confirm.
        elapsedtime_ez(index_otherloop,t,index_node_outer_warm,index_node_outer_dam) = toc;
        disp(['Time spent on value recursion in loop ' num2str(index_otherloop) ', node_outer_warm-dam ' num2str(index_node_outer_warm) ',' num2str(index_node_outer_dam) ': ' num2str(elapsedtime_ez(index_otherloop,t,index_node_outer_warm,index_node_outer_dam)/60) ' minutes.']);

        
    end % end loop over warming nodes
end % end loop over damage nodes

delete(ParObj); % close parallel pool



%% Calculate scc
% With learning (Params.learn_ez==1) the recursion stopped at period 2, so the
% period-1 value is built here by simulating one period ahead from the
% starting state and taking expectations over all nodes. Without learning
% the period-1 approximant is available and is differentiated directly.

if Params.learn_ez==1
            
    % lay the nodes and weights out along separate dimensions:
    % warming along dimension 2, damages along dimension 3
    Params.cons_damage = reshape(Params.cons_damage_original,[1 1 size(nodes_damage,1) 1]);
    Params.warming = reshape(Params.warming_original,[1 size(nodes_climsens,1) 1]);
    weights_damage = reshape(weights_damage,[1 1 length(weights_damage)]);
    weights_climsens = reshape(weights_climsens,[1 length(weights_climsens) 1]);
    
    % starting consumption per capita in levels
    if Params.uselogstates==1
        cons_pc_now = exp(start_state(1));
    else
        cons_pc_now = start_state(1);
    end
        
    % initialize: one entry per (consumption node, warming node, damage node),
    % all equal to the starting state in levels
    clear next_value;
    cons_value = ones(size(Wiener.cons_increment,1),length(Params.warming),length(Params.cons_damage))*cons_pc_now*Params.pop(1); % adjusts for state being per capita
    if Params.uselogstates==1
        M1_value = ones(size(Wiener.cons_increment,1),length(Params.warming),length(Params.cons_damage))*exp(start_state(2));
        M2_value = ones(size(Wiener.cons_increment,1),length(Params.warming),length(Params.cons_damage))*exp(start_state(3));
    else
        M1_value = ones(size(Wiener.cons_increment,1),length(Params.warming),length(Params.cons_damage))*start_state(2);
        M2_value = ones(size(Wiener.cons_increment,1),length(Params.warming),length(Params.cons_damage))*start_state(3);
    end
    temp_value = ones(size(Wiener.cons_increment,1),length(Params.warming),length(Params.cons_damage))*start_state(4);
    
    % simulate to get next period states
    counter_fine = 1;    
    for index_interval = 1:1/Params.timestep
        
        % Calculate changes in variables
        % (sub_transitions is an external script run in this workspace; the
        % lines below rely on it to define the four *change variables)
        index_time = 1;        
        sub_transitions
        
        % Update values of variables
        M1_value = M1_value + M1change;
        M2_value = M2_value + M2change;
        temp_value = temp_value + tempchange;
        cons_value = cons_value + conschange;
        
    end % end stochastic process simulation

    % obtain continuation values: evaluate each node's stored approximant on
    % the period-2 grid bounds at the simulated period-2 states
    for index_node_outer_dam = 1:quadnodes_outer_dam
        for index_node_outer_warm = 1:quadnodes_outer_warm
            fspace_next = fundefn(Params.basistype,Params.basiscoeffs,statemin(2,:,index_node_outer_warm,index_node_outer_dam),statemax(2,:,index_node_outer_warm,index_node_outer_dam));
            % obtain matrix of states
            if Params.uselogstates==1
                next_state = [ log(cons_value(:,index_node_outer_warm,index_node_outer_dam)/Params.pop(1+1)) log(M1_value(:,index_node_outer_warm,index_node_outer_dam)) log(M2_value(:,index_node_outer_warm,index_node_outer_dam)) temp_value(:,index_node_outer_warm,index_node_outer_dam) ];
            else
                next_state = [ cons_value(:,index_node_outer_warm,index_node_outer_dam)/Params.pop(1+1) M1_value(:,index_node_outer_warm,index_node_outer_dam) M2_value(:,index_node_outer_warm,index_node_outer_dam) temp_value(:,index_node_outer_warm,index_node_outer_dam) ];
            end
            next_value(:,index_node_outer_warm,index_node_outer_dam) = funeval(v_coeffs(1,:,index_node_outer_warm,index_node_outer_dam)',fspace_next,next_state);
        end
    end
    next_value = next_value.^(1-Params.risk);
    % take expectations over damage nodes
    next_value = bsxfun(@times,next_value,weights_damage);
    next_value = sum(next_value,3);
    % take expectations over climate sensitivity nodes
    next_value = bsxfun(@times,next_value,weights_climsens);
    next_value = sum(next_value,2);
    % take expectations over consumption volatility
    next_value = weights_cons'*next_value;
    % calculate value by combining with current consumption pc
    v_start = ( (1-exp(-Params.discount))*(Params.pop(1)/totalpop)*cons_pc_now^(1-Params.time) +  exp(-Params.discount)*next_value^((1-Params.time)/(1-Params.risk))  )^(1/(1-Params.time));
    
    % do all again under perturbation
    if Params.perturbation>0
        clear next_value;
        % remember the unperturbed period-2 carbon stocks; they are put back
        % after the perturbed simulation
        M1_store = M1_value;
        M2_store = M2_value;
        cons_value = ones(size(Wiener.cons_increment,1),length(Params.warming),length(Params.cons_damage))*cons_pc_now*Params.pop(1); % adjusts for state being per capita
        % starting carbon stocks plus each reservoir's share of the perturbation
        if Params.uselogstates==1
            M1_value = ones(size(Wiener.cons_increment,1),length(Params.warming),length(Params.cons_damage))*exp(start_state(2)) + Params.perturbation*Params.M1_frac;
            M2_value = ones(size(Wiener.cons_increment,1),length(Params.warming),length(Params.cons_damage))*exp(start_state(3)) + Params.perturbation*Params.M2_frac*exp(-Params.co2_decay*(1-1));
        else
            M1_value = ones(size(Wiener.cons_increment,1),length(Params.warming),length(Params.cons_damage))*start_state(2) + Params.perturbation*Params.M1_frac;
            M2_value = ones(size(Wiener.cons_increment,1),length(Params.warming),length(Params.cons_damage))*start_state(3) + Params.perturbation*Params.M2_frac*exp(-Params.co2_decay*(1-1));
        end
        temp_value = ones(size(Wiener.cons_increment,1),length(Params.warming),length(Params.cons_damage))*start_state(4);
        counter_fine = 1;
        for index_interval = 1:1/Params.timestep
            
            % Calculate changes in variables
            index_time = 1;
            sub_transitions
            
            % Update values of variables
            M1_value = M1_value + M1change;
            M2_value = M2_value + M2change;
            temp_value = temp_value + tempchange;
            cons_value = cons_value + conschange;
            
        end
        % restore CO2 trajectory (the perturbed approximant is evaluated at
        % the unperturbed carbon stocks, with perturbed temperature and consumption)
        M1_value = M1_store;
        M2_value = M2_store;
        % obtain continuation values
        for index_node_outer_dam = 1:quadnodes_outer_dam
            for index_node_outer_warm = 1:quadnodes_outer_warm
                fspace_next = fundefn(Params.basistype,Params.basiscoeffs,statemin(2,:,index_node_outer_warm,index_node_outer_dam),statemax(2,:,index_node_outer_warm,index_node_outer_dam));
                % obtain matrix of states
                if Params.uselogstates==1
                    next_state = [ log(cons_value(:,index_node_outer_warm,index_node_outer_dam)/Params.pop(1+1)) log(M1_value(:,index_node_outer_warm,index_node_outer_dam)) log(M2_value(:,index_node_outer_warm,index_node_outer_dam)) temp_value(:,index_node_outer_warm,index_node_outer_dam) ];
                else
                    next_state = [ cons_value(:,index_node_outer_warm,index_node_outer_dam)/Params.pop(1+1) M1_value(:,index_node_outer_warm,index_node_outer_dam) M2_value(:,index_node_outer_warm,index_node_outer_dam) temp_value(:,index_node_outer_warm,index_node_outer_dam) ];
                end
                next_value(:,index_node_outer_warm,index_node_outer_dam) = funeval(v_coeffs_perturb(1,:,index_node_outer_warm,index_node_outer_dam)',fspace_next,next_state);
            end
        end
        next_value = next_value.^(1-Params.risk);
        % take expectations over damage nodes
        next_value = bsxfun(@times,next_value,weights_damage);
        next_value = sum(next_value,3);
        % take expectations over climate sensitivity nodes
        next_value = bsxfun(@times,next_value,weights_climsens);
        next_value = sum(next_value,2);
        % take expectations over consumption volatility
        next_value = weights_cons'*next_value;
        % calculate value by combining with current consumption pc
        v_start_perturb = ( (1-exp(-Params.discount))*(Params.pop(1)/totalpop)*cons_pc_now^(1-Params.time) +  exp(-Params.discount)*next_value^((1-Params.time)/(1-Params.risk))  )^(1/(1-Params.time));
    end
    
    % calculate derivative wrt current consumption
    deriv_C = (v_start^Params.time)*(1-exp(-Params.discount))*(cons_pc_now^(-Params.time))/totalpop;
    
    % cannot calculate endogenous scc
    scc(:,index_otherloop) = NaN;
    
else
    
    % scc endog CO2: calculate using deriv of time 1 value and do the division by marginal
    % welfare
    % note: these next loops are now trivial
    for index_node_outer_dam = 1:quadnodes_outer_dam
        for index_node_outer_warm = 1:quadnodes_outer_warm
            
            fspace_now = fundefn(Params.basistype,Params.basiscoeffs,statemin(1,:,index_node_outer_warm,index_node_outer_dam),statemax(1,:,index_node_outer_warm,index_node_outer_dam));
            
            if Params.uselogstates==1
                % The approximant is a function of log(M1) and log(M2), so
                % funeval returns dV/dlog(M). By the chain rule
                % dV/dM = (dV/dlog(M))/M, where the level is
                % M = exp(start_state(.)); the consumption term on the third
                % line converts its log state to a level in the same way.
                deriv_M1(index_node_outer_warm,index_node_outer_dam) = funeval(transpose(v_coeffs(1,:,index_node_outer_warm,index_node_outer_dam)),fspace_now,start_state, [0 1 0 0])/exp(start_state(2));
                deriv_M2(index_node_outer_warm,index_node_outer_dam) = funeval(transpose(v_coeffs(1,:,index_node_outer_warm,index_node_outer_dam)),fspace_now,start_state, [0 0 1 0])/exp(start_state(3));
                deriv_C(index_node_outer_warm,index_node_outer_dam) = (funeval(transpose(v_coeffs(1,:,index_node_outer_warm,index_node_outer_dam)),fspace_now,start_state)^Params.time)*(1-exp(-Params.discount))*(exp(start_state(1))^(-Params.time))/totalpop;
            else
                % states are in levels, so the derivatives need no conversion
                deriv_M1(index_node_outer_warm,index_node_outer_dam) = funeval(transpose(v_coeffs(1,:,index_node_outer_warm,index_node_outer_dam)),fspace_now,start_state, [0 1 0 0]);
                deriv_M2(index_node_outer_warm,index_node_outer_dam) = funeval(transpose(v_coeffs(1,:,index_node_outer_warm,index_node_outer_dam)),fspace_now,start_state, [0 0 1 0]);
                deriv_C(index_node_outer_warm,index_node_outer_dam) = (funeval(transpose(v_coeffs(1,:,index_node_outer_warm,index_node_outer_dam)),fspace_now,start_state)^Params.time)*(1-exp(-Params.discount))*(start_state(1)^(-Params.time))/totalpop;
            end
            
        end
    end
    % marginal welfare effect of carbon, split between the two reservoirs by
    % Params.M1_frac and Params.M2_frac, divided by marginal welfare of consumption
    scc(:,index_otherloop) = -(Params.M1_frac*deriv_M1 + Params.M2_frac*deriv_M2)./deriv_C;
    scc(:,index_otherloop) = scc(:,index_otherloop)*1e-9/Params.co2_per_c; %convert units from Gt C to tCO2
    
    % period-1 values at the starting state (the loop indices keep their
    % final values here); the perturbed coefficients only exist when a
    % perturbation was requested
    v_start(index_node_outer_warm,index_node_outer_dam) = funeval(transpose(v_coeffs(1,:,index_node_outer_warm,index_node_outer_dam)),fspace_now,start_state);
    if Params.perturbation>0
        v_start_perturb(index_node_outer_warm,index_node_outer_dam) = funeval(transpose(v_coeffs_perturb(1,:,index_node_outer_warm,index_node_outer_dam)),fspace_now,start_state);
    end
    
end

% scc exog co2: difference quotient of the time-1 value between the baseline
% and perturbed runs, divided by deriv_C to express it in consumption units
if Params.perturbation>0            
    scc_exogCO2(:,index_otherloop) = (( v_start - v_start_perturb )/Params.perturbation)./deriv_C;
    scc_exogCO2(:,index_otherloop) = scc_exogCO2(:,index_otherloop)*1e-9/Params.co2_per_c; %convert units from Gt C to tCO2
else
    % No perturbed run, so this measure is unavailable. Store NaN as a
    % placeholder (as done for scc when it cannot be calculated): assigning []
    % is MATLAB's deletion syntax and would remove the column, shifting the
    % results of other runs, or error if the column does not exist yet.
    scc_exogCO2(:,index_otherloop) = NaN;
end


%% Test value function approximant
% For selected periods, plots the fitted value function along each of the four
% state dimensions (other states held fixed), overlays markers at the
% approximation nodes, and saves each figure as a .jpg in the current folder.

if Params.learn_ez~=1
    
    % Options
    time_plot = [1 10 50 100 max_time-5:1:max_time-1]; % periods to plot
    % keep only periods that actually exist (matters for short horizons) and
    % drop repeats, which would just overwrite the same image file
    time_plot = unique(time_plot(time_plot>=1 & time_plot<=max_time-1));
    % outer (warming,damage) node whose coefficients are plotted
    plot_node_warm = 1;
    plot_node_damage = 1;
    num_plot_points = 101; % points per univariate sweep
    
    % axis labels, in state order (states are logged when uselogstates==1)
    if Params.uselogstates==1
        state_labels = {'Log Cons p.c.','Log M1','Log M2','T'};
    else
        state_labels = {'Cons p.c.','M1','M2','T'};
    end
    
    for index_time = 1:length(time_plot)
        
        % Set up
        time_now = time_plot(index_time);
        state_lower = statemin(time_now,:,1,1);
        state_upper = statemax(time_now,:,1,1);
        if time_now==1
            % first period: hold the other states at the initial state
            plot_state_fixed = start_state;
        else
            % later periods: hold the other states at the midpoint of the bounds
            plot_state_fixed = mean ( [state_lower;state_upper], 1);
        end
        fspace_plot = fundefn(Params.basistype,Params.basiscoeffs,state_lower,state_upper);
        fspace_plot_nodes = fundefn(Params.nodetype,Params.nodes,state_lower,state_upper);
        nodes = funnode(fspace_plot_nodes);
        statematrix_plot = gridmake(nodes); % each row defines one point, where the columns give [log_cons_pc M1 M2 T]
        plot_coeffs = transpose(v_coeffs(time_now,:,plot_node_warm,plot_node_damage));
        
        % Make univariate plots
        h = figure('name', ['Time ' num2str(time_now) ' Value Functs, ' mat2str(Params.basiscoeffs)]);
        for index=1:4
            subplot(2,2,index);
            plot_state = repmat(plot_state_fixed,num_plot_points,1);
            % linspace guarantees exactly num_plot_points rows including both
            % endpoints, even when the bounds coincide
            plot_state(:,index) = transpose(linspace(state_lower(index),state_upper(index),num_plot_points));
            plot_v = funeval(plot_coeffs,fspace_plot,plot_state);
            plot(plot_state(:,index),plot_v);
            xlabel(state_labels{index});
            hold on;
            % overlay the fitted value at this dimension's node locations
            state_markers = unique(statematrix_plot(:,index));
            plot_state2 = repmat(plot_state_fixed,length(state_markers),1);
            plot_state2(:,index) = state_markers;
            plot_v2 = funeval(plot_coeffs,fspace_plot,plot_state2);
            scatter(plot_state2(:,index),plot_v2);
        end
        saveas(h,['valuefunct_time' num2str(time_now) '_coeffs' mat2str(Params.basiscoeffs) '_nodes' mat2str(Params.nodes) '.jpg']);
        close(h);
        
    end
    
end

%% For testing: Evaluate value function along given trajectory
% Fills evaluate_value(t,1) with the baseline value function and, when a
% perturbation was run, evaluate_value(t,2) with the perturbed one, both
% evaluated at the midpoint of period t's state bounds.

if Params.learn_ez~=1
    
    % trajectory: midpoint between the lower and upper state bounds in every
    % period, taken at the first outer node
    evaluate_state = (statemin(:,:,1,1) + statemax(:,:,1,1))/2;
    
    for index_time = 1:max_time-1
        
        fspace_evaluate = fundefn(Params.basistype,Params.basiscoeffs,statemin(index_time,:,1,1),statemax(index_time,:,1,1));
        
        if Params.learn_ez~=1
            
            % single outer node: evaluate directly
            evaluate_value(index_time,1) = funeval(transpose(v_coeffs(index_time,:,1,1)),fspace_evaluate,evaluate_state(index_time,:));
            if Params.perturbation>0
                evaluate_value(index_time,2) = funeval(transpose(v_coeffs_perturb(index_time,:,1,1)),fspace_evaluate,evaluate_state(index_time,:));
            end
            
        else
            
            % NOTE(review): not reachable under the enclosing learn_ez~=1 guard
            % baseline: evaluate at every outer node, then weight
            for index_node_outer_warm = 1:quadnodes_outer_warm
                for index_node_outer_dam = 1:quadnodes_outer_dam
                    temp_value(index_node_outer_warm,index_node_outer_dam) = funeval(transpose(v_coeffs(index_time,:,index_node_outer_warm,index_node_outer_dam)),fspace_evaluate,evaluate_state(index_time,:));
                end
            end
            temp_value = bsxfun(@times,temp_value,weights_damage');
            temp_value = sum(temp_value,2);
            % take expectations over climate sensitivity
            evaluate_value(index_time,1) = weights_climsens'*temp_value;
            
            if Params.perturbation>0
                % perturbed run: same weighting applied to v_coeffs_perturb
                for index_node_outer_warm = 1:quadnodes_outer_warm
                    for index_node_outer_dam = 1:quadnodes_outer_dam
                        temp_value(index_node_outer_warm,index_node_outer_dam) = funeval(transpose(v_coeffs_perturb(index_time,:,index_node_outer_warm,index_node_outer_dam)),fspace_evaluate,evaluate_state(index_time,:));
                    end
                end
                temp_value = bsxfun(@times,temp_value,weights_damage');
                temp_value = sum(temp_value,2);
                % take expectations over climate sensitivity
                evaluate_value(index_time,2) = weights_climsens'*temp_value;
            end
            
        end
        
    end
    
end

%% For testing: Evaluate value function at nodes in some time period
% Builds the node grid for period max_time-1 and stores the fitted value
% function at every grid point in value_node (one entry per grid row).

if Params.learn_ez~=1
    
    index_time = max_time-1;
    
    fspace_evaluate = fundefn(Params.basistype,Params.basiscoeffs,statemin(index_time,:,1,1),statemax(index_time,:,1,1));
    fspace_evaluate_nodes = fundefn(Params.nodetype,Params.nodes,statemin(index_time,:,1,1),statemax(index_time,:,1,1));
    nodes = funnode(fspace_evaluate_nodes);
    statematrix_test = gridmake(nodes); % each row defines one point, where the columns give [cons_pc M1 M2 T] (first three logged when uselogstates==1)
    
    if Params.learn_ez==1
        % NOTE(review): not reachable under the enclosing learn_ez~=1 guard
        % start from a fresh array: temp_value may still hold a differently
        % shaped result from an earlier section
        temp_value = zeros(size(statematrix_test,1),quadnodes_outer_warm,quadnodes_outer_dam);
        for index_node_outer_warm = 1:quadnodes_outer_warm
            for index_node_outer_dam = 1:quadnodes_outer_dam
                % evaluate at every grid point (all rows of statematrix_test)
                temp_value(:,index_node_outer_warm,index_node_outer_dam) = funeval(transpose(v_coeffs(index_time,:,index_node_outer_warm,index_node_outer_dam)),fspace_evaluate,statematrix_test);
            end
        end
        % take expectations over damages
        temp_value = bsxfun(@times,temp_value,reshape(weights_damage,[1 1 length(weights_damage)]));
        temp_value = sum(temp_value,3);
        % take expectations over climate sensitivity
        temp_value = bsxfun(@times,temp_value,reshape(weights_climsens,[1 length(weights_climsens)]));
        value_node = sum(temp_value,2);
    else
        value_node = funeval(transpose(v_coeffs(index_time,:,1,1)),fspace_evaluate,statematrix_test);
    end

end




%% Saving
% Writes the workspace to disk after each run of the outer loop. Runs before
% the last one go to a numbered file; the last run goes to the unnumbered
% results file. Only the most recent numbered file is kept.

if index_otherloop ~= length(otherloop_vector)
    save(sprintf('asset_results%d.mat',index_otherloop));
else
    save('asset_results.mat');
end
% the file written by the preceding run is superseded by the one just saved
if 1 < index_otherloop
    delete(sprintf('asset_results%d.mat',index_otherloop-1));
end